import argparse
import os
import pickle
import timeit

import cv2
import mxnet as mx
import numpy as np
import pandas as pd
import prettytable
import skimage.transform
import torch
from sklearn.metrics import roc_curve
from sklearn.preprocessing import normalize
from torch.utils.data import DataLoader
from onnx_helper import ArcFaceORT

SRC = np.array(
    [
        [30.2946, 51.6963],
        [65.5318, 51.5014],
        [48.0252, 71.7366],
        [33.5493, 92.3655],
        [62.7299, 92.2041]]
    , dtype=np.float32)
SRC[:, 0] += 8.0


@torch.no_grad()
class AlignedDataSet(mx.gluon.data.Dataset):
    def __init__(self, root, lines, align=True):
        self.lines = lines
        self.root = root
        self.align = align

    def __len__(self):
        return len(self.lines)

    def __getitem__(self, idx):
        each_line = self.lines[idx]
        name_lmk_score = each_line.strip().split(' ')
        name = os.path.join(self.root, name_lmk_score[0])
        img = cv2.cvtColor(cv2.imread(name), cv2.COLOR_BGR2RGB)
        landmark5 = np.array([float(x) for x in name_lmk_score[1:-1]], dtype=np.float32).reshape((5, 2))
        st = skimage.transform.SimilarityTransform()
        st.estimate(landmark5, SRC)
        img = cv2.warpAffine(img, st.params[0:2, :], (112, 112), borderValue=0.0)
        img_1 = np.expand_dims(img, 0)
        img_2 = np.expand_dims(np.fliplr(img), 0)
        output = np.concatenate((img_1, img_2), axis=0).astype(np.float32)
        output = np.transpose(output, (0, 3, 1, 2))
        return torch.from_numpy(output)


@torch.no_grad()
def extract(model_root, dataset):
    model = ArcFaceORT(model_path=model_root)
    model.check()
    feat_mat = np.zeros(shape=(len(dataset), 2 * model.feat_dim))

    def collate_fn(data):
        return torch.cat(data, dim=0)

    data_loader = DataLoader(
        dataset, batch_size=128, drop_last=False, num_workers=4, collate_fn=collate_fn, )
    num_iter = 0
    for batch in data_loader:
        batch = batch.numpy()
        batch = (batch - model.input_mean) / model.input_std
        feat = model.session.run(model.output_names, {model.input_name: batch})[0]
        feat = np.reshape(feat, (-1, model.feat_dim * 2))
        feat_mat[128 * num_iter: 128 * num_iter + feat.shape[0], :] = feat
        num_iter += 1
        if num_iter % 50 == 0:
            print(num_iter)
    return feat_mat


def read_template_media_list(path):
    ijb_meta = pd.read_csv(path, sep=' ', header=None).values
    templates = ijb_meta[:, 1].astype(np.int)
    medias = ijb_meta[:, 2].astype(np.int)
    return templates, medias


def read_template_pair_list(path):
    pairs = pd.read_csv(path, sep=' ', header=None).values
    t1 = pairs[:, 0].astype(np.int)
    t2 = pairs[:, 1].astype(np.int)
    label = pairs[:, 2].astype(np.int)
    return t1, t2, label


def read_image_feature(path):
    with open(path, 'rb') as fid:
        img_feats = pickle.load(fid)
    return img_feats


def image2template_feature(img_feats=None,
                           templates=None,
                           medias=None):
    unique_templates = np.unique(templates)
    template_feats = np.zeros((len(unique_templates), img_feats.shape[1]))
    for count_template, uqt in enumerate(unique_templates):
        (ind_t,) = np.where(templates == uqt)
        face_norm_feats = img_feats[ind_t]
        face_medias = medias[ind_t]
        unique_medias, unique_media_counts = np.unique(face_medias, return_counts=True)
        media_norm_feats = []
        for u, ct in zip(unique_medias, unique_media_counts):
            (ind_m,) = np.where(face_medias == u)
            if ct == 1:
                media_norm_feats += [face_norm_feats[ind_m]]
            else:  # image features from the same video will be aggregated into one feature
                media_norm_feats += [np.mean(face_norm_feats[ind_m], axis=0, keepdims=True), ]
        media_norm_feats = np.array(media_norm_feats)
        template_feats[count_template] = np.sum(media_norm_feats, axis=0)
        if count_template % 2000 == 0:
            print('Finish Calculating {} template features.'.format(
                count_template))
    template_norm_feats = normalize(template_feats)
    return template_norm_feats, unique_templates


def verification(template_norm_feats=None,
                 unique_templates=None,
                 p1=None,
                 p2=None):
    template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int)
    for count_template, uqt in enumerate(unique_templates):
        template2id[uqt] = count_template
    score = np.zeros((len(p1),))
    total_pairs = np.array(range(len(p1)))
    batchsize = 100000
    sublists = [total_pairs[i: i + batchsize] for i in range(0, len(p1), batchsize)]
    total_sublists = len(sublists)
    for c, s in enumerate(sublists):
        feat1 = template_norm_feats[template2id[p1[s]]]
        feat2 = template_norm_feats[template2id[p2[s]]]
        similarity_score = np.sum(feat1 * feat2, -1)
        score[s] = similarity_score.flatten()
        if c % 10 == 0:
            print('Finish {}/{} pairs.'.format(c, total_sublists))
    return score


def verification2(template_norm_feats=None,
                  unique_templates=None,
                  p1=None,
                  p2=None):
    template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int)
    for count_template, uqt in enumerate(unique_templates):
        template2id[uqt] = count_template
    score = np.zeros((len(p1),))  # save cosine distance between pairs
    total_pairs = np.array(range(len(p1)))
    batchsize = 100000  # small batchsize instead of all pairs in one batch due to the memory limiation
    sublists = [total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize)]
    total_sublists = len(sublists)
    for c, s in enumerate(sublists):
        feat1 = template_norm_feats[template2id[p1[s]]]
        feat2 = template_norm_feats[template2id[p2[s]]]
        similarity_score = np.sum(feat1 * feat2, -1)
        score[s] = similarity_score.flatten()
        if c % 10 == 0:
            print('Finish {}/{} pairs.'.format(c, total_sublists))
    return score


def main(args):
    use_norm_score = True  # if Ture, TestMode(N1)
    use_detector_score = True  # if Ture, TestMode(D1)
    use_flip_test = True  # if Ture, TestMode(F1)
    assert args.target == 'IJBC' or args.target == 'IJBB'

    start = timeit.default_timer()
    templates, medias = read_template_media_list(
        os.path.join('%s/meta' % args.image_path, '%s_face_tid_mid.txt' % args.target.lower()))
    stop = timeit.default_timer()
    print('Time: %.2f s. ' % (stop - start))

    start = timeit.default_timer()
    p1, p2, label = read_template_pair_list(
        os.path.join('%s/meta' % args.image_path,
                     '%s_template_pair_label.txt' % args.target.lower()))
    stop = timeit.default_timer()
    print('Time: %.2f s. ' % (stop - start))

    start = timeit.default_timer()
    img_path = '%s/loose_crop' % args.image_path
    img_list_path = '%s/meta/%s_name_5pts_score.txt' % (args.image_path, args.target.lower())
    img_list = open(img_list_path)
    files = img_list.readlines()
    dataset = AlignedDataSet(root=img_path, lines=files, align=True)
    img_feats = extract(args.model_root, dataset)

    faceness_scores = []
    for each_line in files:
        name_lmk_score = each_line.split()
        faceness_scores.append(name_lmk_score[-1])
    faceness_scores = np.array(faceness_scores).astype(np.float32)
    stop = timeit.default_timer()
    print('Time: %.2f s. ' % (stop - start))
    print('Feature Shape: ({} , {}) .'.format(img_feats.shape[0], img_feats.shape[1]))
    start = timeit.default_timer()

    if use_flip_test:
        img_input_feats = img_feats[:, 0:img_feats.shape[1] // 2] + img_feats[:, img_feats.shape[1] // 2:]
    else:
        img_input_feats = img_feats[:, 0:img_feats.shape[1] // 2]

    if use_norm_score:
        img_input_feats = img_input_feats
    else:
        img_input_feats = img_input_feats / np.sqrt(np.sum(img_input_feats ** 2, -1, keepdims=True))

    if use_detector_score:
        print(img_input_feats.shape, faceness_scores.shape)
        img_input_feats = img_input_feats * faceness_scores[:, np.newaxis]
    else:
        img_input_feats = img_input_feats

    template_norm_feats, unique_templates = image2template_feature(
        img_input_feats, templates, medias)
    stop = timeit.default_timer()
    print('Time: %.2f s. ' % (stop - start))

    start = timeit.default_timer()
    score = verification(template_norm_feats, unique_templates, p1, p2)
    stop = timeit.default_timer()
    print('Time: %.2f s. ' % (stop - start))
    result_dir = args.model_root

    save_path = os.path.join(result_dir, "{}_result".format(args.target))
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    score_save_file = os.path.join(save_path, "{}.npy".format(args.target))
    np.save(score_save_file, score)
    files = [score_save_file]
    methods = []
    scores = []
    for file in files:
        methods.append(os.path.basename(file))
        scores.append(np.load(file))
    methods = np.array(methods)
    scores = dict(zip(methods, scores))
    x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1]
    tpr_fpr_table = prettytable.PrettyTable(['Methods'] + [str(x) for x in x_labels])
    for method in methods:
        fpr, tpr, _ = roc_curve(label, scores[method])
        fpr = np.flipud(fpr)
        tpr = np.flipud(tpr)
        tpr_fpr_row = []
        tpr_fpr_row.append("%s-%s" % (method, args.target))
        for fpr_iter in np.arange(len(x_labels)):
            _, min_index = min(
                list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr)))))
            tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100))
        tpr_fpr_table.add_row(tpr_fpr_row)
    print(tpr_fpr_table)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='do ijb test')
    # general
    parser.add_argument('--model-root', default='', help='path to load model.')
    parser.add_argument('--image-path', default='/train_tmp/IJB_release/IJBC', type=str, help='')
    parser.add_argument('--target', default='IJBC', type=str, help='target, set to IJBC or IJBB')
    main(parser.parse_args())
